""" Extracts the desired chains of PDB entries and writes them out as new PDB structures. """
from Bio.PDB import PDBParser, Select
from Bio.PDB.PDBIO import PDBIO
from pathlib import Path
import pandas as pd 
import re
import os


class ABSelect(Select):
    """Bio.PDB ``Select`` that keeps only the chains listed in ``chains``.

    ``chains`` may be any container supporting ``in``: a list of chain IDs,
    or a string such as ``"AB"``. With a string, membership is a substring
    test, so each single-character chain ID is matched individually.
    """

    def __init__(self, chains):
        super().__init__()
        # Chain IDs to keep; consulted once per chain by accept_chain.
        self.chains = chains

    def accept_chain(self, chain):
        """Return True if ``chain`` should be written out, else False."""
        wanted = self.chains
        return chain.get_id() in wanted


def extract_chains_from_pdb(pdb_path, selected_chains, save_path):
    """Write the selected chains of a PDB file to a new PDB file.

    Args:
        pdb_path: Path to the input PDB file. A leading ``pdb`` in the file
            stem (e.g. ``pdb1abc.pdb``) is dropped from the output file name.
        selected_chains: Chain IDs to keep, either a string such as ``"AB"``
            or an iterable of IDs such as ``["A", "B"]``.
        save_path: Directory the output file is written into. It is not
            created by this function.

    Returns:
        Path (str) of the written file, ``<save_path>/<id>_<chains>.pdb``.
    """
    name = Path(pdb_path).stem
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure(name, pdb_path)
    io = PDBIO()
    io.set_structure(structure)
    selected = ABSelect(selected_chains)

    # Strip the "pdb" prefix only when it is actually there; a blind
    # name[3:] would mangle stems such as "1abc".
    pdb_id = name[3:] if name.startswith("pdb") else name
    # Strings are used as-is; lists/tuples/other iterables of IDs are joined.
    if isinstance(selected_chains, str):
        chain_tag = selected_chains
    else:
        chain_tag = "".join(selected_chains)
    save_pdb_path = os.path.join(save_path, f"{pdb_id}_{chain_tag}.pdb")
    io.save(save_pdb_path, selected)
    return save_pdb_path


def extract_pdb_chains(pdb_dir, data_path, out_dir):
    """Extract the chains listed in a CSV table from a local PDB tree.

    Args:
        pdb_dir: Root directory whose files are laid out as
            ``<pdb_dir>/<pdb_id[1:3]>/pdb<pdb_id>.pdb``.
        data_path: CSV file whose first column is used as the index and which
            has ``pdb_id`` and ``chain_id`` columns; ``chain_id`` may be a
            comma-separated list such as ``"A,B"``.
        out_dir: Directory the extracted files are written to.

    Returns:
        List of the written file paths, in table order.
    """
    # Force string dtype: otherwise pandas parses IDs such as "1e10" as
    # floats and purely numeric chain IDs as ints.
    data = pd.read_csv(
        data_path, index_col=0, dtype={"pdb_id": str, "chain_id": str}
    )
    saved = []
    for row in data.itertuples(index=False):
        pdb_id = row.pdb_id
        chain_id = row.chain_id.replace(",", "")
        # os.path.join works whether or not pdb_dir ends with a separator.
        pdb_path = os.path.join(pdb_dir, pdb_id[1:3], f"pdb{pdb_id}.pdb")
        saved.append(
            extract_chains_from_pdb(pdb_path, chain_id, save_path=out_dir)
        )
    return saved


def generate_gt_pdb_txt(gt_pdb_dir):
    """Write the paths of all files under ``gt_pdb_dir`` to a sibling .txt file.

    The directory is walked recursively. Every file path is written, one per
    line, to ``<gt_pdb_dir>.txt``, which sits next to the directory rather
    than inside it. The number of files found is printed.

    Args:
        gt_pdb_dir: Directory to scan (str or path-like).

    Returns:
        Path (str) of the written list file.
    """
    pdb_paths = [
        os.path.join(root, fname)
        for root, _, file_list in os.walk(gt_pdb_dir)
        for fname in file_list
    ]
    # normpath drops a trailing separator, so "dir/" gives "dir.txt" rather
    # than a hidden "dir/.txt" inside the scanned directory.
    txt_save_path = os.path.normpath(os.fspath(gt_pdb_dir)) + ".txt"
    print(len(pdb_paths))
    with open(txt_save_path, "w", encoding="utf-8") as f:
        f.write("\n".join(pdb_paths))
    return txt_save_path

    